-
Notifications
You must be signed in to change notification settings - Fork 22.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Lowering for FlexAttention Backwards #125515
Add Lowering for FlexAttention Backwards #125515
Conversation
8333271
to
b725537
Compare
aa5cfda
to
9d10e81
Compare
9d10e81
to
a317d9d
Compare
@@ -320,6 +328,8 @@ def store_output( | |||
indices: Union[List[Any], Tuple[Any]], | |||
val: str, | |||
mask: Optional[str] = None, | |||
indent_width: int = 4, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to figure out this one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we want indent width, but probably definitely not fake_output lol
d996891
to
50a2da9
Compare
50a2da9
to
8159d4a
Compare
@drisspg your PR has been successfully reverted. |
@huydhn hmmm so what exactly do I need to do? |
If you do mean to use >6GB of memory, do something like #126399
We're using 3 for sm86 |
c7968cd
to
69e8b7a
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@pytorchbot -h |
PyTorchBot Help
Merge
Revert
Rebase
Label
Dr CI
cherry-pick
Closeusage: @pytorchbot close Close a PR [Can be used on issues] |
@pytorchbot merge -i |
Merge startedYour change will be merged while ignoring the following 1 checks: Lint / lintrunner-noclang / linux-job Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@pytorchbot merge -f "spoke with Catherine, I added all of flex's test to the serial list, and these failures look unrelated" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: pytorch#123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: pytorch#125515 Approved by: https://github.com/Chillee
This reverts commit 95b9e98. Reverted pytorch#125515 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the newly added test runs out of memory https://hud.pytorch.org/pytorch/pytorch/commit/95b9e981c3ab68fc17f78b8a6bbfd9569745ae4c ([comment](pytorch#125515 (comment)))
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: pytorch#123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: pytorch#125515 Approved by: https://github.com/Chillee
Summary
What does this PR do?
It enables Inductor to actually generate the fused flex attention kernel for the backwards
I did some other things along the way:
TritonTemplateKernel
to actually accept multiple subgraphs (modifications)Benchmark Numbers:
The current implementation is not parallelizing over ctx length in the bwd
FWD Speedups
BWD Speedups
Full Data
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang